Conversation
float8_experimental/float8_tensor.py
# return Float8Tensor(tensors["_data"], tensors["_scale"], metadatas[0])
return Float8Tensor(inner_tensors["_data"], metadata["_scale"], metadata["_orig_dtype"])
assert len(inner_tensors) == 2
return Float8Tensor(inner_tensors["_data"], inner_tensors["_scale"], metadata["_orig_dtype"])
@bdhirsh okay, I had a conjecture that the way we were using subclasses was not really right, so with this, plus commenting out these two lines:
https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/symbolic_shapes.py#L2791-L2792
I think I am getting somewhere. I had to comment these out since "_data" is not the same shape as the scalar tensor "_scale".
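For context, a minimal sketch of the flatten/unflatten pairing implied by the diff above, assuming the traceable-wrapper-subclass protocol of this era (`__tensor_flatten__` returns the inner-tensor attribute names plus metadata, `__tensor_unflatten__` rebuilds the wrapper). The key point is that `_scale` is registered as an inner tensor rather than stashed in metadata; the class name, constructor, and exact signatures here are illustrative, not the repo's code:

```python
import torch

class Float8TensorSketch(torch.Tensor):
    # Illustrative wrapper subclass: _data holds the fp8 payload, _scale is a
    # scalar tensor, _orig_dtype remembers the precision to upcast back to.
    @staticmethod
    def __new__(cls, data, scale, orig_dtype):
        self = torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=orig_dtype, device=data.device
        )
        self._data = data
        self._scale = scale
        self._orig_dtype = orig_dtype
        return self

    def __tensor_flatten__(self):
        # Both _data and _scale are real inner tensors; only orig_dtype is metadata.
        return ["_data", "_scale"], {"_orig_dtype": self._orig_dtype}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata):
        assert len(inner_tensors) == 2
        return Float8TensorSketch(
            inner_tensors["_data"], inner_tensors["_scale"], metadata["_orig_dtype"]
        )
```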
Ah yep, this is exactly the dynamic shapes issue that I hit with Vasiliy, so this is probably hi-pri to fix.
Right now there's a constraint with dynamic shapes and wrapper subclasses: we expect your inner tensors to have the same dynamic-ness as the actual wrapper's shape. Instead, dynamo should track the dynamic-ness of each inner tensor individually, and in the FP8 case it should recognize that the scale is static and doesn't change sizes.
pretty cool!
def to_original_precision(self):
    return FromFloat8ConstrFunc.apply(self)

@staticmethod
@torch._dynamo.allow_in_graph
@zou3519 without allow_in_graph this breaks
I think the long-term correct thing here is also:
(1) to_float8() remains a (static?) method on the subclass
(2) dynamo eventually understands that it should proxy custom methods on "traceable subclasses" directly into the torch graph
That way dynamo won't try to inline into to_float8() here and try to carefully understand the custom autograd.Function you wrote.
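To make that concrete, here is a rough sketch of the pattern under discussion (the names and the simplified cast are illustrative, not the repo's exact code): the fp8 conversion goes through a custom autograd.Function, and the entry point is decorated with torch._dynamo.allow_in_graph so dynamo proxies the call into the graph instead of trying to inline into the Function's forward/backward:

```python
import torch
import torch._dynamo

class ToFloat8ConstrFunc(torch.autograd.Function):
    """Simplified constructor function: cast a high-precision tensor to fp8."""

    @staticmethod
    def forward(ctx, tensor, scale):
        # Real code would clamp to the fp8 dtype's range and wrap the result
        # in Float8Tensor; here we just cast the scaled values.
        return (tensor * scale).to(torch.float8_e4m3fn)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through gradient for the data, no gradient for the scale.
        return grad_output, None

@torch._dynamo.allow_in_graph  # keep dynamo from inlining into the Function above
def to_float8(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return ToFloat8ConstrFunc.apply(tensor, scale)
```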
@@ -5,6 +5,7 @@ set -e

pytest test/test_base.py
pytest test/test_sam.py
pytest test/test_compile.py
I absolutely hate that internal keeps changing the GD file execution permissions on the shell scripts
y = torch.matmul(x_fp8, w_fp8.t())

# Cast gradY to float8_e5m2 during backward
y = self.cast_y_to_float8_in_bw(y, self.emulate)
Mentioned offline, but food for thought that I'll mention here: it would be interesting to think about what it would take to have aten.matmul(Float8Tensor, Float8Tensor) actually return another Float8Tensor, and then leave it to the subclass to know to upcast on future ops that don't want to handle Float8 directly.
My understanding was:
(1) This is a pain mostly because the extra buffers for float8 live directly on the Float8Linear nn module today and not on the subclass (probably for good reason)
(2) Doing this would provide benefit if we want to start increasing the number of ops that directly handle float8, but if all we care about is linear then this generality is probably not very useful.
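As a rough illustration of the "upcast on future ops" idea (the helper and its arguments are hypothetical, not existing code): a __torch_dispatch__ fallback could simply convert any float8 wrapper arguments back to their original precision and re-run the op normally.

```python
from torch.utils._pytree import tree_map

def upcast_fallback(func, args, kwargs, float8_cls):
    # Fallback for ops without a dedicated float8 handler: upcast every
    # float8 wrapper argument to its original precision, then run the op
    # as usual. `float8_cls` / `to_original_precision` stand in for the
    # subclass API; this is a sketch, not the repo's implementation.
    def maybe_upcast(x):
        return x.to_original_precision() if isinstance(x, float8_cls) else x

    args = tree_map(maybe_upcast, args)
    kwargs = tree_map(maybe_upcast, kwargs or {})
    return func(*args, **kwargs)
```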
def add_weight_tag(self):
    # We add a tag to the weight nn.Parameter in order to signal
    # to FSDP that this param is a weight
    self.weight._is_fp8_weight = True
I'm reviewing the subclass changes but probably not the right person to review this one
This was added in a previous PR; I just moved it to the Mixin so that it can be added to the TP stuff.
output_dtype = a._orig_dtype
if a._emulate:
    assert a._emulate == b._emulate
    return torch.ops.aten.mm_float8_emulated(
Oh also just thinking - should emulate just be a global config somewhere, instead of a flag that you have to plumb around?
Talked about this with Brian offline. This is probably right, but I am going to do it in a followup. I also want to see if we can get plain torch.nn.functional.linear working in LinearFloat8, and I will do some matmul changes then.
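For reference, a minimal sketch of what that global config could look like (the module and attribute names are assumptions, not the final API), so that emulate is read from one place instead of being plumbed through every tensor and handler:

```python
from dataclasses import dataclass

@dataclass
class Float8GlobalConfig:
    # If True, matmuls fall back to an emulated float8 path instead of a
    # real fp8 kernel.
    emulate: bool = False

# Single module-level instance that op handlers consult.
config = Float8GlobalConfig()

def should_emulate() -> bool:
    return config.emulate
```

Op handlers would then check should_emulate() instead of asserting that a._emulate == b._emulate.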
the subclass changes lgtm! (I didn't fully read the non-subclass changes)
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Summary
We use the dispatching mechanism to handle mm on Float8Tensor.
TODO
Note
Vasiliy has already started this here:
#28
Some things have changed since then, though: we now output in higher precision by default. However, I still need to replicate the amax_buffer filling here and store it on the float8_tensor passed in.
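As a hedge on what that amax bookkeeping might look like (the function and buffer names below are assumptions, not the final API), the core of it is recording the absolute max of the incoming tensor into a buffer that the scale computation reads later:

```python
import torch

def fill_amax_buffer(amax_buffer: torch.Tensor, x: torch.Tensor) -> None:
    # Record the absolute max of `x` in-place; a later step derives the
    # float8 scale from this buffer.
    amax_buffer.fill_(x.abs().max())

# Usage sketch
amax_y = torch.zeros(1)
fill_amax_buffer(amax_y, torch.randn(16, 16))
```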
Corresponding core changes to get as far as possible in compile with the aot_eager backend:
pytorch/pytorch#111735
Current Compile Progress